# -*- coding: utf-8 -*-
"""
Created on Thu Apr 16 11:23:57 2020

@author: xiang_yaobing
"""

# -*- coding: utf-8 -*-
"""
Created on Wed Apr  16 19:24:30 2019

@author: xiang_yaobing
内容：
1.fits可视化代码段
2.csv 可视化代码段
version = 1.2
"""

# Used to read FITS files and convert them into other formats.
from astropy.io import fits
import numpy as np
import pandas as pd
from numpy import nan as nan
from pandas import Series, DataFrame
import data_genrate_1 as dg
import os 
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
import cv2
def mkdir(path):
    '''
    Create the directory ``path`` (with any missing parents) unless it exists.

    A short status line is printed in both cases; nothing is returned.
    '''
    if os.path.exists(path):
        print ("---  There is this folder(it had already been build)!  ---")
        return
    # makedirs also creates every missing intermediate directory.
    os.makedirs(path)
    print ("---  new folder...  ---")
    print ("---  OK  ---")
        
def get_data(dat):
    '''
    Copy the first two fields of every record of a table into a float matrix.

    ``dat`` is indexed as ``dat[i][j]`` for ``i`` in ``range(dat.size)``
    (in this module it is the ``.data`` of a FITS HDU). The result has shape
    ``(dat.size, 5)``. Only columns 0 and 1 are filled. Columns 2-4 stay
    zero because the parallax / pmRA / pmDE extraction is disabled.
    '''
    n_rows = dat.size
    table = np.zeros([n_rows, 5])
    for row in range(n_rows):
        record = dat[row]
        table[row, 0] = record[0]
        table[row, 1] = record[1]
    return table
# NOTE: disabled scratch code, kept as a bare string literal so it never runs.
# It opened one sample FITS file, extracted columns with get_data() and
# dropped rows containing NaN.
'''        
data=fits.open('C:\\Users\\xiang_yaobing\\Desktop\\天文\\cluster_data\\正负样本\\cluster\\berkeley19.r0.1.fits')#查看单个数据
dat = data[1].data
#data_arr=np.zeros([dat.size, 5])
data_arr = get_data(dat)
data_deal = DataFrame(data_arr)
#data_deal.dropna(axis = 0, how = 'any')  
df4 = data_deal.dropna() # to dropout the line include nan
data_deal = np.array(df4)

#np.savetxt('../data_center/q.txt', data_deal)
'''

def data_wash(data_deal, fa=0.1):
    '''
    Trim a margin of width ``fa`` from every edge of the RA/Dec footprint.

    Parameters
    ----------
    data_deal : np.ndarray
        2-D array whose column 0 is RA and column 1 is Dec. Any extra
        columns are carried through unchanged.
    fa : float, optional
        Margin removed from each side of both coordinate ranges. Defaults
        to 0.1, matching the ``loc0.1`` output folder of the CSV pipeline.
        Before this default existed, every one-argument call raised TypeError.

    Returns
    -------
    washed : np.ndarray
        Rows strictly inside the trimmed box, in their original order.
    loc_value : list
        Untrimmed bounds ``[ra_min, ra_max, dec_min, dec_max]``.
    loc_value_s : list
        Trimmed bounds ``[ra_min+fa, ra_max-fa, dec_min+fa, dec_max-fa]``.
        These are used as view limits by the plotting code.

    Raises
    ------
    ValueError
        If ``data_deal`` contains no rows.
    '''
    if data_deal.shape[0] == 0:
        raise ValueError("data_wash needs at least one row of RA/Dec data")
    ra = data_deal[:, 0]
    dec = data_deal[:, 1]
    raj_min, raj_max = np.min(ra), np.max(ra)
    dej_min, dej_max = np.min(dec), np.max(dec)
    loc_value = [raj_min, raj_max, dej_min, dej_max]
    loc_value_s = [raj_min + fa, raj_max - fa, dej_min + fa, dej_max - fa]
    # Strict inequalities: points lying exactly on a trimmed edge are dropped.
    inside = ((ra > loc_value_s[0]) & (ra < loc_value_s[1])
              & (dec > loc_value_s[2]) & (dec < loc_value_s[3]))
    return data_deal[inside], loc_value, loc_value_s

def save_the_data_subpicture(file_path, save_path, fa=0.1):
    '''
    Render and save a multi-panel figure for every FITS file in a folder.

    Parameters
    ----------
    file_path : str
        Folder whose entries are all FITS files (table data in HDU 1).
    save_path : str
        Output prefix. ``<save_path><file name>.jpg`` is written per file,
        so it should end with a path separator.
    fa : float, optional
        Edge margin passed to ``data_wash``. The previous call omitted it
        and raised TypeError.

    Returns 0.
    '''
    for name in os.listdir(file_path):
        fit = os.path.join(file_path, name)
        print(fit)
        # Copy the table out while the file is open so the handle is closed
        # (the original never closed it).
        with fits.open(fit) as hdul:
            data_arr = get_data(hdul[1].data)
        data_deal = np.array(DataFrame(data_arr).dropna())  # drop rows with NaN
        washed, loc_value, loc_value_s = data_wash(data_deal, fa)
        # NOTE(review): data_wash no longer yields a proper-motion range, and
        # get_data leaves the pm columns zero. The trimmed bounds go in the
        # slot the old code named loc_value, and the untrimmed bounds in the
        # slot it named pm_value. Confirm what dg.cluster_subplot expects.
        dg.cluster_subplot(washed, loc_value_s, loc_value)
        # Use the listed file name rather than splitting the path on '\\',
        # which broke on non-Windows separators.
        plt.savefig(save_path + name + '.jpg')
    return 0
        

def save_fits_data_realpicture(file_path, save_path, fa=0.1):
    '''
    Generate machine-learning input images for every FITS file in a folder.

    Parameters
    ----------
    file_path : str
        Windows folder path ending with a backslash. Its last component
        names the output sub-folder.
    save_path : str
        Output root ending with a backslash. Images go to
        ``<save_path><folder>\\loc\\`` and ``<save_path><folder>\\vel\\``.
    fa : float, optional
        Edge margin passed to ``data_wash``. The previous call omitted it
        and raised TypeError.

    Returns 0.
    '''
    names = os.listdir(file_path)  # list before creating any output folders
    # With a trailing backslash, element [-2] is the folder's own name.
    file_name = (file_path.split('\\'))[-2]
    new_loc_path = save_path + file_name + '\\loc\\'
    new_vel_path = save_path + file_name + '\\vel\\'
    mkdir(new_loc_path)
    mkdir(new_vel_path)
    for name in names:
        fit = os.path.join(file_path, name)
        print(fit)
        # Copy the table out while the file is open so the handle is closed.
        with fits.open(fit) as hdul:
            data_arr = get_data(hdul[1].data)
        data_deal = np.array(DataFrame(data_arr).dropna())  # drop rows with NaN
        washed, loc_value, loc_value_s = data_wash(data_deal, fa)
        loc_save_path = new_loc_path + name + '.jpg'
        vel_save_path = new_vel_path + name + '.jpg'
        # Trimmed data is drawn inside the trimmed bounds, as in the CSV pipeline.
        dg.cluster_loc_plotsave(loc_save_path, washed, loc_value_s)
        # NOTE(review): get_data leaves the proper-motion columns zero and
        # data_wash returns no pm range. The untrimmed bounds are passed in
        # the old pm_value slot, so this velocity plot is not meaningful
        # until pm extraction is re-enabled.
        dg.cluster_vel_plotsave(vel_save_path, washed, loc_value)
    return 0

def deal_csv_array(file_path):
    '''
    Load a CSV file and return its ``ra`` and ``dec`` columns as a 2-D array.

    Rows where either coordinate is missing are dropped. ``file_path`` may be
    anything ``pandas.read_csv`` accepts.
    '''
    coords = pd.read_csv(file_path)[['ra', 'dec']]
    return coords.dropna().to_numpy()

def gauss_kde_location(mix_stars, loc_value, grid_size=60):
    '''
    Evaluate a 2-D Gaussian kernel density estimate of star positions on a grid.

    Parameters
    ----------
    mix_stars : np.ndarray
        2-D array whose columns 0 and 1 are the two position coordinates.
    loc_value : sequence
        Grid bounds ``[x_min, x_max, y_min, y_max]``.
    grid_size : int, optional
        Number of grid points along each axis. Default 60, the value that
        was previously hard-coded.

    Returns
    -------
    z : np.ndarray
        Flattened densities of length ``grid_size**2``, in the row-major
        order of the meshgrid.
    xgrid : np.ndarray
        The ``(grid_size, grid_size)`` x-coordinate mesh. Callers use its
        ``shape`` to reshape ``z``.
    '''
    positions = np.vstack([mix_stars[:, 0], mix_stars[:, 1]])
    kde = gaussian_kde(positions)
    xgrid, ygrid = np.meshgrid(
        np.linspace(loc_value[0], loc_value[1], grid_size),
        np.linspace(loc_value[2], loc_value[3], grid_size),
    )
    z = kde.evaluate(np.vstack([xgrid.ravel(), ygrid.ravel()]))
    return z, xgrid

def save_csv_data_realpicture(file_path, save_path, fa=0.1):
    '''
    Generate machine-learning position images for every CSV file in a folder.

    For each file, two plots are saved:

    * ``<save_path><folder>\\loc\\<name>_loc.jpg``: all stars, full bounds.
    * ``<save_path><folder>\\loc0.1\\<name>_loc0.1.jpg``: stars left after
      ``data_wash`` trims an ``fa`` margin, drawn within the trimmed bounds.

    Parameters
    ----------
    file_path : str
        Windows folder path ending with a backslash.
    save_path : str
        Output root ending with a backslash.
    fa : float, optional
        Edge margin for ``data_wash``. Default 0.1 matches the folder and
        file naming. The previous call omitted it and raised TypeError.

    Returns 0.
    '''
    names = os.listdir(file_path)  # list before creating any output folders
    # With a trailing backslash, element [-2] is the folder's own name.
    file_name = (file_path.split('\\'))[-2]
    full_dir = save_path + file_name + '\\loc\\'
    trimmed_dir = save_path + file_name + '\\loc0.1\\'
    mkdir(full_dir)
    mkdir(trimmed_dir)
    for name in names:
        csv = os.path.join(file_path, name)
        print(csv)
        data_deal = deal_csv_array(csv)
        washed, loc_value, loc_value_s = data_wash(data_deal, fa)
        # Use the listed file name rather than splitting the path on '\\'.
        dg.cluster_loc_plotsave(full_dir + name + '_loc.jpg', data_deal, loc_value)
        dg.cluster_loc_plotsave(trimmed_dir + name + '_loc0.1.jpg', washed, loc_value_s)
    return 0


# Earlier FITS-plot configuration, kept for reference:
#file_path = 'E:\\天文\\cluster_data\\正负样本\\cluster_mini\\cluster\\'
#save_path = 'E:\\天文\\cluster_data\\正负样本\\machine_use_picture_fa1\\'
#save_the_data_subpicture(file_path, save_path)
# Input folder of CSV catalogues and output root for the generated images,
# used by the __main__ script. Both are Windows paths and must end with a
# backslash: the folder name is taken as file_path.split('\\')[-2], and
# sub-folders are appended as save_path + name + '\\loc\\'.
file_path = 'F:\\疏散星团检测\\cluster dataset\\train_csv_dataset\\csv_dataset_no_cluster\\'
save_path = '..\\trainset\\noClusterKDEpicture_multiKde\\'
# save_csv_data_realpicture(file_path, save_path)
def _kde_channel(stars, bounds):
    '''
    Return one image channel: the inverted, 0-255 scaled KDE of ``stars``.

    The density from ``gauss_kde_location`` is reshaped to its grid,
    min-max scaled to 0..255, truncated to integers and inverted, so the
    densest cell becomes 0 (dark) and the emptiest becomes 255.
    '''
    density, grid = gauss_kde_location(stars, bounds)
    density = density.reshape(grid.shape)
    low, high = density.min(), density.max()
    if high > low:
        density = (density - low) / (high - low) * 255
    else:
        # A perfectly flat density would divide by zero (NaN pixels); treat
        # it as uniformly empty instead.
        density = np.zeros_like(density)
    # Built-in int replaces the np.int alias removed in NumPy 1.24. The
    # truncation behaviour is unchanged.
    return 255 - density.astype(int)


def main(src_dir, dst_dir):
    '''
    Write a 3-channel KDE image for every CSV catalogue in ``src_dir``.

    Channel 0 covers the full footprint. Channels 1 and 2 cover the stars
    left after ``data_wash`` trims a 0.05 and a 0.1 margin, each evaluated
    within its own trimmed bounds. Images go to
    ``<dst_dir><folder>\\loc\\<name>_loc.jpg``. Both directory arguments
    are Windows paths that must end with a backslash.
    '''
    names = os.listdir(src_dir)  # list before creating any output folders
    folder = src_dir.split('\\')[-2]
    loc_dir = dst_dir + folder + '\\loc\\'
    mkdir(loc_dir)
    # Created as before; nothing is currently written here.
    mkdir(dst_dir + folder + '\\loc0.1\\')
    for name in names:
        csv = os.path.join(src_dir, name)
        print(csv)
        stars = deal_csv_array(csv)
        stars_005, bounds, bounds_005 = data_wash(stars, 0.05)
        stars_01, _, bounds_01 = data_wash(stars, 0.1)
        multi_kde = np.dstack([
            _kde_channel(stars, bounds),
            _kde_channel(stars_005, bounds_005),
            _kde_channel(stars_01, bounds_01),
        ])
        # Values are already integers in 0..255; write an explicit 8-bit image.
        cv2.imwrite(loc_dir + name + '_loc.jpg', multi_kde.astype(np.uint8))


if __name__ == '__main__':
    main(file_path, save_path)

